Binary Classification using a deep neural network with ‘L’ hidden layers.
Notation:
Train dataset.
| Sepal.Length <dbl> | Sepal.Width <dbl> | Species <int> |
|---|---|---|
| 5.7 | 2.6 | 1 |
| 4.5 | 2.3 | 0 |
| 5.0 | 2.3 | 1 |
| 5.1 | 3.8 | 0 |
| 5.6 | 2.7 | 1 |
| 5.6 | 2.5 | 1 |
| 5.0 | 3.6 | 0 |
| 5.4 | 3.0 | 1 |
| 5.0 | 3.0 | 0 |
| 5.4 | 3.4 | 0 |
Test dataset.
| Sepal.Length <dbl> | Sepal.Width <dbl> | Species <int> |
|---|---|---|
| 5.1 | 3.3 | 0 |
| 5.4 | 3.4 | 0 |
| 5.3 | 3.7 | 0 |
| 5.0 | 3.4 | 0 |
| 6.1 | 3.0 | 1 |
| 4.9 | 2.4 | 1 |
| 5.5 | 2.6 | 1 |
| 4.3 | 3.0 | 0 |
| 5.1 | 3.8 | 0 |
| 5.4 | 3.9 | 0 |
Plot the data.
Standardize the feature columns (zero mean and unit variance) with scale(). Note that scale() does not rescale values to the [0, 1] range; it subtracts the column mean and divides by the column standard deviation.
library(dplyr)  # provides select(), mutate_all() and the pipe
X_train <- iris %>%
  select(Sepal.Length, Sepal.Width) %>%
  mutate_all(function(x) scale(x))
y_train <- iris %>%
  select(Species)
X_test <- test %>%
  select(Sepal.Length, Sepal.Width) %>%
  mutate_all(function(x) scale(x))
y_test <- test %>%
  select(Species)
## Shape of X (row, column):
## 80 2
## Shape of y (row, column) :
## 80 1
## Number of training samples:
## 80
## Shape of X_test (row, column):
## 20 2
## Shape of y_test (row, column) :
## 20 1
## Number of testing samples:
## 20
Convert the inputs and outputs to matrices, then transpose X and y so that each column holds one training example. This makes the matrix calculations slightly less verbose.
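A minimal sketch of that conversion, assuming the data frames built above:
# Coerce to matrices and transpose: X becomes (n_features x m), y becomes (1 x m)
X_train <- t(as.matrix(X_train))
y_train <- t(as.matrix(y_train))
X_test <- t(as.matrix(X_test))
y_test <- t(as.matrix(y_test))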
The model’s structure is: LINEAR -> RELU -> LINEAR -> SIGMOID.
initializeParametersShallow <- function(n_x, n_h, n_y){
  # Weights get random uniform values, biases start at zero
  W1 <- matrix(runif(n_h * n_x), nrow = n_h, ncol = n_x, byrow = TRUE)
  b1 <- matrix(rep(0, n_h), nrow = n_h, ncol = 1)
  W2 <- matrix(runif(n_y * n_h), nrow = n_y, ncol = n_h, byrow = TRUE)
  b2 <- matrix(rep(0, n_y), nrow = n_y, ncol = 1)
  params <- list("W1" = W1,
                 "b1" = b1,
                 "W2" = W2,
                 "b2" = b2)
  return (params)
}
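The parameter shapes printed below can be reproduced with a call along these lines (the exact call is not shown in the source):
# 2 input features, 4 hidden units, 1 output unit
params_shallow <- initializeParametersShallow(n_x = 2, n_h = 4, n_y = 1)
lapply(params_shallow, dim)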
## $W1
## [1] 4 2
##
## $b1
## [1] 4 1
##
## $W2
## [1] 1 4
##
## $b2
## [1] 1 1
The deep model’s structure is: LINEAR -> RELU -> LINEAR -> RELU -> LINEAR -> SIGMOID.
That is, the layer dimensions are 2 -> 3 -> 4 -> 1.
initializeParameters <- function(layer_dims){
  # layer_dims: list of layer sizes, e.g. list(2, 3, 4, 1)
  L = length(layer_dims)
  params <- list()
  for (i in 2:L) {
    # W_l has shape (layer_dims[l], layer_dims[l-1]); b_l has shape (layer_dims[l], 1)
    w <- paste("W", i-1, sep="")
    w_mat <- matrix(runif(layer_dims[[i]] * layer_dims[[i-1]]),
                    nrow = layer_dims[[i]], ncol = layer_dims[[i-1]], byrow = TRUE)
    params[[w]] <- w_mat
    b <- paste("b", i-1, sep="")
    b_mat <- matrix(rep(0, layer_dims[[i]]), nrow = layer_dims[[i]], byrow = TRUE)
    params[[b]] <- b_mat
  }
  return (params)
}
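The dimension listing below can be reproduced with a call along these lines (the call itself is not shown in the source):
params <- initializeParameters(list(2, 3, 4, 1))
lapply(params, dim)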
## $W1
## [1] 3 2
##
## $b1
## [1] 3 1
##
## $W2
## [1] 4 3
##
## $b2
## [1] 4 1
##
## $W3
## [1] 1 4
##
## $b3
## [1] 1 1
We’ll use a sample architecture of 2 -> 3 -> 4 -> 1 to test our functions, where 2 is the number of input neurons and 1 is the number of output neurons.
We will write three functions to implement forward propagation: linearForward(), linearActivationForward(), and LModelForward(). First, define the activation functions they rely on.
Define the sigmoid function.
sigmoid <- function(Z){
  # Element-wise sigmoid; the input Z is cached for later use
  A <- 1 / (1 + exp(-Z))
  cache <- Z
  out <- list("A" = A,
              "cache" = cache)
  return (out)
}
Define the relu function.
relu <- function(Z){
  # Element-wise max(0, Z); pmax() works for scalars, vectors and matrices,
  # whereas an if () test on Z is only valid for a single value
  A <- pmax(Z, 0)
  cache <- Z
  out <- list("A" = A,
              "cache" = cache)
  return (out)
}
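A quick sanity check of the two activations (illustrative values only):
sigmoid(0)$A   # 0.5
relu(-2)$A     # 0
relu(3)$A      # 3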
The equation for the linear part of forward propagation is:
$$Z^{[l]} = W^{[l]} A^{[l-1]} + b^{[l]},$$
where $A^{[0]} = X$.
linearForward <- function(A, W, b) {
  m <- dim(A)[2]
  # Broadcast the bias column vector across all m examples
  b_new <- matrix(rep(b, m), nrow = dim(W)[1])
  Z <- W %*% A + b_new
  # Cache the inputs; they are needed to compute gradients in backprop
  cache <- list("A" = A,
                "W" = W,
                "b" = b)
  out <- list("Z" = Z,
              "cache" = cache)
  return (out)
}
linear_forward1 <- linearForward(X_train, params$W1, params$b1)
linear_forward2 <- linearForward(linear_forward1$Z, params$W2, params$b2)
linear_forward3 <- linearForward(linear_forward2$Z, params$W3, params$b3)
## Shape of Z1:
## 3 80
##
## linear_cache:
## $A
## [1] 2 80
##
## $W
## [1] 3 2
##
## $b
## [1] 3 1
## Shape of Z2:
## 4 80
##
## linear_cache:
## $A
## [1] 3 80
##
## $W
## [1] 4 3
##
## $b
## [1] 4 1
## Shape of Z3:
## 1 80
##
## linear_cache:
## $A
## [1] 4 80
##
## $W
## [1] 1 4
##
## $b
## [1] 1 1
Now we will implement the forward propagation of the LINEAR -> ACTIVATION layer.
The mathematical relation is $A^{[l]} = g(Z^{[l]}) = g(W^{[l]} A^{[l-1]} + b^{[l]})$, where the activation $g$ can be sigmoid() or relu().
linearActivationForward <- function(A_prev, W, b, activation) {
  linear_forward <- linearForward(A_prev, W, b)
  Z <- linear_forward$Z
  linear_cache <- linear_forward$cache
  # Apply the chosen activation element-wise, computing it only once
  if (activation == "sigmoid") {
    A <- apply(X = Z, MARGIN = c(1, 2), function(x) sigmoid(x)$A)
  } else if (activation == "relu") {
    A <- apply(X = Z, MARGIN = c(1, 2), function(x) relu(x)$A)
  }
  # The activation cache stores A = g(Z); for relu and sigmoid this is
  # enough to recover g'(Z) during backprop
  activation_cache <- A
  cache <- list("linear_cache" = linear_cache,
                "activation_cache" = activation_cache)
  out <- list("A" = A,
              "cache" = cache)
  return (out)
}
linear_activation_forward1 <- linearActivationForward(A_prev = X_train, W = params$W1, b = params$b1, activation = "relu")
linear_activation_forward2 <- linearActivationForward(A_prev = linear_activation_forward1$A, W = params$W2, b = params$b2, activation = "relu")
linear_activation_forward3 <- linearActivationForward(A_prev = linear_activation_forward2$A, W = params$W3, b = params$b3, activation = "relu")
# linear_activation_forward1
cat("Shape of A1: \n", dim(linear_activation_forward1$A), "\n\n", "Linear Cache: \n")
## Shape of A1:
## 3 80
##
## Linear Cache:
## $A
## [1] 2 80
##
## $W
## [1] 3 2
##
## $b
## [1] 3 1
##
## Activation Cache:
## [1] 3 80
# linear_activation_forward2
cat("Shape of A2: \n", dim(linear_activation_forward2$A), "\n\n", "Linear Cache: \n")
## Shape of A2:
## 4 80
##
## Linear Cache:
## $A
## [1] 3 80
##
## $W
## [1] 4 3
##
## $b
## [1] 4 1
##
## Activation Cache:
## [1] 4 80
# linear_activation_forward3
cat("Shape of A3: \n", dim(linear_activation_forward3$A), "\n\n", "Linear Cache: \n")
## Shape of A3:
## 1 80
##
## Linear Cache:
## $A
## [1] 4 80
##
## $W
## [1] 1 4
##
## $b
## [1] 1 1
##
## Activation Cache:
## [1] 1 80
The full forward pass stacks [LINEAR -> RELU] (L-1) times and finishes with a LINEAR -> SIGMOID step, collecting every cache along the way:
LModelForward <- function(X, parameters) {
  caches = list()
  A = X
  # Each layer contributes one W and one b, so the number of layers is half the list length
  L = floor(length(parameters)/2)
  A_prev <- A
  # Hidden layers: LINEAR -> RELU
  for (i in 1:(L-1)) {
    W <- paste("W", i, sep = "")
    b <- paste("b", i, sep = "")
    linear_activation_forward <- linearActivationForward(A_prev, parameters[[W]], parameters[[b]], activation = "relu")
    A <- linear_activation_forward$A
    cache <- linear_activation_forward$cache
    caches <- append(caches, cache)
    A_prev <- A
  }
  # Output layer: LINEAR -> SIGMOID (computed once and reused for AL and its cache)
  WL <- paste("W", L, sep = "")
  bL <- paste("b", L, sep = "")
  last_forward <- linearActivationForward(A_prev, parameters[[WL]], parameters[[bL]], activation = "sigmoid")
  AL <- last_forward$A
  caches <- append(caches, last_forward$cache)
  out <- list("AL" = AL,
              "caches" = caches)
  return (out)
}
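The caches dumped below come from a forward pass over the training data with the parameters initialized above, i.e. a call along these lines:
L_model_forward <- LModelForward(X_train, params)
L_model_forward$caches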
## $linear_cache
## $linear_cache$A
## [,1] [,2] [,3] [,4] [,5] [,6]
## Sepal.Length 0.2982138 -1.548788 -0.7792037 -0.6252869 0.1442970 0.144297
## Sepal.Width -1.0569860 -1.678742 -1.6787425 1.4300399 -0.8497339 -1.264238
## [,7] [,8] [,9] [,10] [,11] [,12]
## Sepal.Length -0.7792037 -0.1635366 -0.7792037 -0.1635366 1.5295480 1.0677977
## Sepal.Width 1.0155356 -0.2279774 -0.2279774 0.6010313 -0.6424817 -0.4352295
## [,13] [,14] [,15] [,16] [,17] [,18]
## Sepal.Length 0.4521305 -0.7792037 1.6834648 0.7599641 0.9138809 1.221714
## Sepal.Width -0.8497339 0.8082834 -0.2279774 -0.4352295 -0.6424817 -1.678742
## [,19] [,20] [,21] [,22] [,23]
## Sepal.Length -0.6252869 0.2982138 0.9138809 -0.009619799 -1.08703727
## Sepal.Width 1.4300399 -0.4352295 -0.6424817 -1.471490339 -0.02072522
## [,24] [,25] [,26] [,27] [,28] [,29]
## Sepal.Length -0.4713701 0.7599641 1.8373816 0.7599641 -1.7027044 -1.0870373
## Sepal.Width 0.8082834 -0.8497339 -0.2279774 -1.8859947 0.1865269 0.6010313
## [,30] [,31] [,32] [,33] [,34]
## Sepal.Length -0.009619799 -1.39487084 -0.7792037 -0.1635366 1.9912984
## Sepal.Width 2.259048549 -0.02072522 -2.3004990 1.6372921 -0.6424817
## [,35] [,36] [,37] [,38] [,39] [,40]
## Sepal.Length -1.0870373 -1.7027044 0.2982138 -0.4713701 0.4521305 -0.9331205
## Sepal.Width 0.6010313 -0.2279774 -0.2279774 0.6010313 1.8445442 1.0155356
## [,41] [,42] [,43] [,44] [,45] [,46]
## Sepal.Length -0.7792037 -0.9331205 2.2991319 -0.6252869 -1.3948708 0.1442970
## Sepal.Width 0.1865269 -0.2279774 0.1865269 -1.2642382 0.6010313 -0.2279774
## [,47] [,48] [,49] [,50] [,51] [,52]
## Sepal.Length 1.3756312 1.83738159 1.067798 -0.7792037 -0.1635366 -1.0870373
## Sepal.Width -0.4352295 -0.02072522 -1.885995 0.8082834 1.2227877 -0.2279774
## [,53] [,54] [,55] [,56] [,57]
## Sepal.Length -0.7792037 -0.7792037 -0.6252869 0.9138809 0.4521305
## Sepal.Width 0.6010313 0.3937791 0.8082834 -0.4352295 -0.8497339
## [,58] [,59] [,60] [,61] [,62]
## Sepal.Length -0.009619799 -0.93312049 -1.2409541 1.2217145 -1.3948708
## Sepal.Width 0.808283426 -0.02072522 0.1865269 0.3937791 0.1865269
## [,63] [,64] [,65] [,66] [,67]
## Sepal.Length 1.83738159 0.1442970 0.6060473 0.1442970 -0.009619799
## Sepal.Width -0.02072522 -0.4352295 -0.2279774 -0.2279774 -1.264238179
## [,68] [,69] [,70] [,71] [,72] [,73]
## Sepal.Length -0.4713701 0.2982138 0.4521305 0.6060473 -0.6252869 -1.0870373
## Sepal.Width 2.0517964 2.6735529 -1.0569860 0.1865269 0.8082834 -0.2279774
## [,74] [,75] [,76] [,77] [,78] [,79]
## Sepal.Length -0.4713701 2.14521515 1.6834648 -0.6252869 0.2982138 -0.6252869
## Sepal.Width -0.8497339 -0.02072522 -0.4352295 0.6010313 1.4300399 1.2227877
## [,80]
## Sepal.Length -0.4713701
## Sepal.Width 1.0155356
##
## $linear_cache$W
## [,1] [,2]
## [1,] 0.03907962 0.5702809
## [2,] 0.29190193 0.8887148
## [3,] 0.61510693 0.9397041
##
## $linear_cache$b
## [,1]
## [1,] 0
## [2,] 0
## [3,] 0
##
##
## $activation_cache
## [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10] [,11]
## [1,] 0 0 0 0.7910885 0 0 0.5486896 0 0 0.3363657 0.0000000
## [2,] 0 0 0 1.0883752 0 0 0.6750705 0 0 0.4864088 0.0000000
## [3,] 0 0 0 0.9591960 0 0 0.4750093 0 0 0.4641991 0.3370929
## [,12] [,13] [,14] [,15] [,16] [,17] [,18] [,19] [,20]
## [1,] 0.0000000 0 0.4304976 0.0000000 0.00000000 0 0 0.7910885 0
## [2,] 0.0000000 0 0.4908824 0.2887997 0.00000000 0 0 1.0883752 0
## [3,] 0.2478228 0 0.2802536 0.8212796 0.05847222 0 0 0.9591960 0
## [,21] [,22] [,23] [,24] [,25] [,26] [,27] [,28] [,29]
## [1,] 0 0 0 0.4425276 0 0.0000000 0 0.03983171 0.3002757
## [2,] 0 0 0 0.5807396 0 0.3337284 0 0.00000000 0.2168371
## [3,] 0 0 0 0.4696042 0 0.9159549 0 0.00000000 0.0000000
## [,30] [,31] [,32] [,33] [,34] [,35] [,36] [,37] [,38]
## [1,] 1.287916 0 0 0.9273255 0.00000000 0.3002757 0 0 0.3243357
## [2,] 2.004842 0 0 1.4073491 0.01028082 0.2168371 0 0 0.3965515
## [3,] 2.116920 0 0 1.4379776 0.62111875 0.0000000 0 0 0.2748485
## [,39] [,40] [,41] [,42] [,43] [,44] [,45] [,46]
## [1,] 1.069577 0.5426746 0.07592177 0 0.1962220 0 0.2882456 0
## [2,] 1.771252 0.6301419 0.00000000 0 0.8368903 0 0.1269799 0
## [3,] 2.011434 0.3803341 0.00000000 0 1.5894921 0 0.0000000 0
## [,47] [,48] [,49] [,50] [,51] [,52] [,53] [,54]
## [1,] 0.00000000 0.05998498 0 0.4304976 0.6909416 0 0.31230567 0.1941137
## [2,] 0.01475447 0.51791642 0 0.4908824 1.0389730 0 0.30669433 0.1225063
## [3,] 0.43717333 1.11071058 0 0.2802536 1.0484662 0 0.08549794 0.0000000
## [,55] [,56] [,57] [,58] [,59] [,60] [,61] [,62]
## [1,] 0.4365126 0.0000000 0 0.4605727 0 0.05787674 0.2723088 0.05186173
## [2,] 0.5358110 0.0000000 0 0.7155254 0 0.00000000 0.7065781 0.00000000
## [3,] 0.3749289 0.1531475 0 0.7536300 0 0.00000000 1.1215209 0.00000000
## [,63] [,64] [,65] [,66] [,67] [,68] [,69] [,70] [,71]
## [1,] 0.05998498 0 0.0000000 0 0 1.151679 1.536330 0 0.1300569
## [2,] 0.51791642 0 0.0000000 0 0 1.685868 2.463075 0 0.3426756
## [3,] 1.11071058 0 0.1585526 0 0 1.638138 2.695782 0 0.5480640
## [,72] [,73] [,74] [,75] [,76] [,77] [,78] [,79]
## [1,] 0.4365126 0 0 0.0720150 0.0000000 0.3183207 0.8271785 0.6728965
## [2,] 0.5358110 0 0 0.6077736 0.1046117 0.3516229 1.3579468 0.9041871
## [3,] 0.3749289 0 0 1.3000611 0.6265239 0.1801732 1.5272477 0.7644403
## [,80]
## [1,] 0.5607196
## [2,] 0.7649277
## [3,] 0.6643599
##
## $linear_cache
## $linear_cache$A
## [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10] [,11]
## [1,] 0 0 0 0.7910885 0 0 0.5486896 0 0 0.3363657 0.0000000
## [2,] 0 0 0 1.0883752 0 0 0.6750705 0 0 0.4864088 0.0000000
## [3,] 0 0 0 0.9591960 0 0 0.4750093 0 0 0.4641991 0.3370929
## [,12] [,13] [,14] [,15] [,16] [,17] [,18] [,19] [,20]
## [1,] 0.0000000 0 0.4304976 0.0000000 0.00000000 0 0 0.7910885 0
## [2,] 0.0000000 0 0.4908824 0.2887997 0.00000000 0 0 1.0883752 0
## [3,] 0.2478228 0 0.2802536 0.8212796 0.05847222 0 0 0.9591960 0
## [,21] [,22] [,23] [,24] [,25] [,26] [,27] [,28] [,29]
## [1,] 0 0 0 0.4425276 0 0.0000000 0 0.03983171 0.3002757
## [2,] 0 0 0 0.5807396 0 0.3337284 0 0.00000000 0.2168371
## [3,] 0 0 0 0.4696042 0 0.9159549 0 0.00000000 0.0000000
## [,30] [,31] [,32] [,33] [,34] [,35] [,36] [,37] [,38]
## [1,] 1.287916 0 0 0.9273255 0.00000000 0.3002757 0 0 0.3243357
## [2,] 2.004842 0 0 1.4073491 0.01028082 0.2168371 0 0 0.3965515
## [3,] 2.116920 0 0 1.4379776 0.62111875 0.0000000 0 0 0.2748485
## [,39] [,40] [,41] [,42] [,43] [,44] [,45] [,46]
## [1,] 1.069577 0.5426746 0.07592177 0 0.1962220 0 0.2882456 0
## [2,] 1.771252 0.6301419 0.00000000 0 0.8368903 0 0.1269799 0
## [3,] 2.011434 0.3803341 0.00000000 0 1.5894921 0 0.0000000 0
## [,47] [,48] [,49] [,50] [,51] [,52] [,53] [,54]
## [1,] 0.00000000 0.05998498 0 0.4304976 0.6909416 0 0.31230567 0.1941137
## [2,] 0.01475447 0.51791642 0 0.4908824 1.0389730 0 0.30669433 0.1225063
## [3,] 0.43717333 1.11071058 0 0.2802536 1.0484662 0 0.08549794 0.0000000
## [,55] [,56] [,57] [,58] [,59] [,60] [,61] [,62]
## [1,] 0.4365126 0.0000000 0 0.4605727 0 0.05787674 0.2723088 0.05186173
## [2,] 0.5358110 0.0000000 0 0.7155254 0 0.00000000 0.7065781 0.00000000
## [3,] 0.3749289 0.1531475 0 0.7536300 0 0.00000000 1.1215209 0.00000000
## [,63] [,64] [,65] [,66] [,67] [,68] [,69] [,70] [,71]
## [1,] 0.05998498 0 0.0000000 0 0 1.151679 1.536330 0 0.1300569
## [2,] 0.51791642 0 0.0000000 0 0 1.685868 2.463075 0 0.3426756
## [3,] 1.11071058 0 0.1585526 0 0 1.638138 2.695782 0 0.5480640
## [,72] [,73] [,74] [,75] [,76] [,77] [,78] [,79]
## [1,] 0.4365126 0 0 0.0720150 0.0000000 0.3183207 0.8271785 0.6728965
## [2,] 0.5358110 0 0 0.6077736 0.1046117 0.3516229 1.3579468 0.9041871
## [3,] 0.3749289 0 0 1.3000611 0.6265239 0.1801732 1.5272477 0.7644403
## [,80]
## [1,] 0.5607196
## [2,] 0.7649277
## [3,] 0.6643599
##
## $linear_cache$W
## [,1] [,2] [,3]
## [1,] 0.08013768 0.0625536 0.1121582
## [2,] 0.26414429 0.1112459 0.6665622
## [3,] 0.71345435 0.3634207 0.8498899
## [4,] 0.07357475 0.0597748 0.2804513
##
## $linear_cache$b
## [,1]
## [1,] 0
## [2,] 0
## [3,] 0
## [4,] 0
##
##
## $activation_cache
## [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10]
## [1,] 0 0 0 0.2390594 0 0 0.1394750 0 0 0.1094459
## [2,] 0 0 0 0.9694026 0 0 0.5366553 0 0 0.4523776
## [3,] 0 0 0 1.7751546 0 0 1.0405052 0 0 0.8112707
## [4,] 0 0 0 0.3922693 0 0 0.2139389 0 0 0.1840082
## [,11] [,12] [,13] [,14] [,15] [,16] [,17] [,18]
## [1,] 0.03780772 0.02779534 0 0.09663827 0.1101787 0.006558136 0 0
## [2,] 0.22469340 0.16518930 0 0.35512861 0.5795617 0.038975370 0 0
## [3,] 0.28649186 0.21062207 0 0.72372196 0.8029530 0.049694945 0 0
## [4,] 0.09453813 0.06950221 0 0.13961364 0.2475918 0.016398606 0 0
## [,19] [,20] [,21] [,22] [,23] [,24] [,25] [,26] [,27]
## [1,] 0.2390594 0 0 0 0 0.1244604 0 0.1236077 0
## [2,] 0.9694026 0 0 0 0 0.4945164 0 0.6476668 0
## [3,] 1.7751546 0 0 0 0 0.9258879 0 0.8997446 0
## [4,] 0.3922693 0 0 0 0 0.1989735 0 0.2768292 0
## [,28] [,29] [,30] [,31] [,32] [,33] [,34]
## [1,] 0.003192021 0.03762734 0.4660505 0 0 0.3236294 0.07030663
## [2,] 0.010521320 0.10343833 1.9742850 0 0 1.3600110 0.41515799
## [3,] 0.028418109 0.29303607 3.4466194 0 0 2.3951868 0.53161881
## [4,] 0.002930608 0.03505410 0.8082900 0 0 0.5556344 0.17480807
## [,35] [,36] [,37] [,38] [,39] [,40] [,41] [,42]
## [1,] 0.03762734 0 0 0.08162374 0.4221104 0.1255639 0.006084195 0
## [2,] 0.10343833 0 0 0.31298976 1.8203134 0.4669614 0.020054302 0
## [3,] 0.29303607 0 0 0.60910470 3.1163019 0.9394222 0.054166717 0
## [4,] 0.03505410 0 0 0.12464831 0.7486794 0.1842589 0.005585925 0
## [,43] [,44] [,45] [,46] [,47] [,48] [,49] [,50]
## [1,] 0.2463498 0 0.03104239 0 0.0499555 0.1617798 0 0.09663827
## [2,] 1.2044269 0 0.09026443 0 0.2930446 0.8138185 0 0.35512861
## [3,] 1.7950320 0 0.25179723 0 0.3769113 1.1749998 0 0.72372196
## [4,] 0.5102370 0 0.02879780 0 0.1234878 0.3468719 0 0.13961364
## [,51] [,52] [,53] [,54] [,55] [,56] [,57] [,58]
## [1,] 0.2379560 0 0.05380158 0.02321903 0.1105494 0.01717674 0 0.1661937
## [2,] 0.9969576 0 0.17360193 0.06490235 0.4248225 0.10208234 0 0.7035982
## [3,] 1.7616203 0 0.40693874 0.18301259 0.8248049 0.13015851 0 1.2291369
## [4,] 0.4069839 0 0.06528841 0.02160466 0.1692936 0.04295041 0 0.2880134
## [,59] [,60] [,61] [,62] [,63] [,64] [,65] [,66]
## [1,] 0 0.004638108 0.1918089 0.004156079 0.1617798 0 0.01778297 0
## [2,] 0 0.015287811 0.8980962 0.013698981 0.8138185 0 0.10568520 0
## [3,] 0 0.041292413 1.4042343 0.037000978 1.1749998 0 0.13475229 0
## [4,] 0 0.004258267 0.3768026 0.003815714 0.3468719 0 0.04446629 0
## [,67] [,68] [,69] [,70] [,71] [,72] [,73] [,74] [,75]
## [1,] 0 0.3814806 0.5795461 0 0.0933279 0.1105494 0 0 0.1896020
## [2,] 0 1.5836765 2.4767262 0 0.4377938 0.4248225 0 0 0.9532063
## [3,] 0 2.8265872 4.2823518 0 0.6831191 0.8248049 0 0 1.3771657
## [4,] 0 0.6449249 1.0163004 0 0.1837575 0.1692936 0 0 0.4062318
## [,76] [,77] [,78] [,79] [,80]
## [1,] 0.0768136 0.06771266 0.3225259 0.1962227 0.1672971
## [2,] 0.4292548 0.24329585 1.3875661 0.7878759 0.6760431
## [3,] 0.5704944 0.50802172 2.3816525 1.4583714 1.2426711
## [4,] 0.1819626 0.09496836 0.5703490 0.3179440 0.2732988
##
## $linear_cache
## $linear_cache$A
## [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10]
## [1,] 0 0 0 0.2390594 0 0 0.1394750 0 0 0.1094459
## [2,] 0 0 0 0.9694026 0 0 0.5366553 0 0 0.4523776
## [3,] 0 0 0 1.7751546 0 0 1.0405052 0 0 0.8112707
## [4,] 0 0 0 0.3922693 0 0 0.2139389 0 0 0.1840082
## [,11] [,12] [,13] [,14] [,15] [,16] [,17] [,18]
## [1,] 0.03780772 0.02779534 0 0.09663827 0.1101787 0.006558136 0 0
## [2,] 0.22469340 0.16518930 0 0.35512861 0.5795617 0.038975370 0 0
## [3,] 0.28649186 0.21062207 0 0.72372196 0.8029530 0.049694945 0 0
## [4,] 0.09453813 0.06950221 0 0.13961364 0.2475918 0.016398606 0 0
## [,19] [,20] [,21] [,22] [,23] [,24] [,25] [,26] [,27]
## [1,] 0.2390594 0 0 0 0 0.1244604 0 0.1236077 0
## [2,] 0.9694026 0 0 0 0 0.4945164 0 0.6476668 0
## [3,] 1.7751546 0 0 0 0 0.9258879 0 0.8997446 0
## [4,] 0.3922693 0 0 0 0 0.1989735 0 0.2768292 0
## [,28] [,29] [,30] [,31] [,32] [,33] [,34]
## [1,] 0.003192021 0.03762734 0.4660505 0 0 0.3236294 0.07030663
## [2,] 0.010521320 0.10343833 1.9742850 0 0 1.3600110 0.41515799
## [3,] 0.028418109 0.29303607 3.4466194 0 0 2.3951868 0.53161881
## [4,] 0.002930608 0.03505410 0.8082900 0 0 0.5556344 0.17480807
## [,35] [,36] [,37] [,38] [,39] [,40] [,41] [,42]
## [1,] 0.03762734 0 0 0.08162374 0.4221104 0.1255639 0.006084195 0
## [2,] 0.10343833 0 0 0.31298976 1.8203134 0.4669614 0.020054302 0
## [3,] 0.29303607 0 0 0.60910470 3.1163019 0.9394222 0.054166717 0
## [4,] 0.03505410 0 0 0.12464831 0.7486794 0.1842589 0.005585925 0
## [,43] [,44] [,45] [,46] [,47] [,48] [,49] [,50]
## [1,] 0.2463498 0 0.03104239 0 0.0499555 0.1617798 0 0.09663827
## [2,] 1.2044269 0 0.09026443 0 0.2930446 0.8138185 0 0.35512861
## [3,] 1.7950320 0 0.25179723 0 0.3769113 1.1749998 0 0.72372196
## [4,] 0.5102370 0 0.02879780 0 0.1234878 0.3468719 0 0.13961364
## [,51] [,52] [,53] [,54] [,55] [,56] [,57] [,58]
## [1,] 0.2379560 0 0.05380158 0.02321903 0.1105494 0.01717674 0 0.1661937
## [2,] 0.9969576 0 0.17360193 0.06490235 0.4248225 0.10208234 0 0.7035982
## [3,] 1.7616203 0 0.40693874 0.18301259 0.8248049 0.13015851 0 1.2291369
## [4,] 0.4069839 0 0.06528841 0.02160466 0.1692936 0.04295041 0 0.2880134
## [,59] [,60] [,61] [,62] [,63] [,64] [,65] [,66]
## [1,] 0 0.004638108 0.1918089 0.004156079 0.1617798 0 0.01778297 0
## [2,] 0 0.015287811 0.8980962 0.013698981 0.8138185 0 0.10568520 0
## [3,] 0 0.041292413 1.4042343 0.037000978 1.1749998 0 0.13475229 0
## [4,] 0 0.004258267 0.3768026 0.003815714 0.3468719 0 0.04446629 0
## [,67] [,68] [,69] [,70] [,71] [,72] [,73] [,74] [,75]
## [1,] 0 0.3814806 0.5795461 0 0.0933279 0.1105494 0 0 0.1896020
## [2,] 0 1.5836765 2.4767262 0 0.4377938 0.4248225 0 0 0.9532063
## [3,] 0 2.8265872 4.2823518 0 0.6831191 0.8248049 0 0 1.3771657
## [4,] 0 0.6449249 1.0163004 0 0.1837575 0.1692936 0 0 0.4062318
## [,76] [,77] [,78] [,79] [,80]
## [1,] 0.0768136 0.06771266 0.3225259 0.1962227 0.1672971
## [2,] 0.4292548 0.24329585 1.3875661 0.7878759 0.6760431
## [3,] 0.5704944 0.50802172 2.3816525 1.4583714 1.2426711
## [4,] 0.1819626 0.09496836 0.5703490 0.3179440 0.2732988
##
## $linear_cache$W
## [,1] [,2] [,3] [,4]
## [1,] 0.2740802 0.632204 0.7860847 0.1559307
##
## $linear_cache$b
## [,1]
## [1,] 0
##
##
## $activation_cache
## [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10] [,11]
## [1,] 0.5 0.5 0.5 0.8942559 0.5 0.5 0.7736063 0.5 0.5 0.7275865 0.5968513
## [,12] [,13] [,14] [,15] [,16] [,17] [,18] [,19] [,20]
## [1,] 0.571618 0.5 0.6988183 0.7439108 0.5170083 0.5 0.5 0.8942559 0.5
## [,21] [,22] [,23] [,24] [,25] [,26] [,27] [,28] [,29]
## [1,] 0.5 0.5 0.5 0.7513079 0.5 0.767416 0.5 0.50758 0.5772573
## [,30] [,31] [,32] [,33] [,34] [,35] [,36] [,37] [,38]
## [1,] 0.9853889 0.5 0.5 0.9487274 0.6741223 0.5772573 0.5 0.5 0.6722743
## [,39] [,40] [,41] [,42] [,43] [,44] [,45] [,46]
## [1,] 0.9788123 0.749665 0.5144451 0.5 0.9104851 0.5 0.5666014 0.5
## [,47] [,48] [,49] [,50] [,51] [,52] [,53] [,54]
## [1,] 0.6258597 0.8229666 0.5 0.6988183 0.8950827 0.5 0.6117244 0.5485039
## [,55] [,56] [,57] [,58] [,59] [,60] [,61] [,62]
## [1,] 0.7258441 0.5444468 0.5 0.8177957 0.5 0.5110131 0.8560612 0.5098689
## [,63] [,64] [,65] [,66] [,67] [,68] [,69] [,70] [,71]
## [1,] 0.8229666 0.5 0.5460067 0.5 0.5 0.9685763 0.9947768 0.5 0.7043288
## [,72] [,73] [,74] [,75] [,76] [,77] [,78] [,79]
## [1,] 0.7258441 0.5 0.5 0.8582098 0.6833608 0.6425681 0.9491525 0.8516827
## [,80]
## [1,] 0.8164848
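The cost function implemented next is the cross-entropy cost over the m training examples:
$$J = -\frac{1}{m}\sum_{i=1}^{m}\Big(y^{(i)}\log a^{[L](i)} + \big(1-y^{(i)}\big)\log\big(1-a^{[L](i)}\big)\Big).$$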
computeCost <- function(AL, y) {
  # AL and y are (1 x m) matrices, so m is the number of examples
  m <- dim(y)[2]
  logprobs <- (log(AL) * y) + (log(1-AL) * (1-y))
  cost <- -sum(logprobs/m)
  return (cost)
}
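The value printed below corresponds to a call along these lines (the call itself is not shown in the source):
cost <- computeCost(L_model_forward$AL, y_train)
cost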
## [1] 1.042769
Backpropagation is used to calculate the gradient of the loss function with respect to the parameters.
Define the relu_backward function.
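The relu_backward code does not appear in the source; a minimal sketch that is consistent with the caches built above (which store the layer activations) is:
relu_backward <- function(dA, cache) {
  # The gradient flows only where the cached value is positive,
  # i.e. where the forward pass did not clip the input to zero
  dZ <- as.matrix(dA * (cache > 0))
  return (dZ)
}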
Define the sigmoid_backward function.
sigmoid_backward <- function(dA, cache) {
  # The activation cache built above stores A = sigmoid(Z),
  # so the local derivative is sigma'(Z) = A * (1 - A)
  s <- cache
  dZ <- as.matrix(dA * s * (1-s))
  return (dZ)
}
For layer $l$, the linear part is $Z^{[l]} = W^{[l]} A^{[l-1]} + b^{[l]}$ (followed by an activation).
Suppose you have already calculated the derivative $dZ^{[l]} = \frac{\partial \mathcal{L}}{\partial Z^{[l]}}$. You want to get $\big(dW^{[l]}, db^{[l]}, dA^{[l-1]}\big)$.
The three outputs are computed using the input $dZ^{[l]}$. Here are the formulas you need:
$$dW^{[l]} = \frac{1}{m}\, dZ^{[l]} {A^{[l-1]}}^{T}$$
$$db^{[l]} = \frac{1}{m}\sum_{i=1}^{m} dZ^{[l](i)}$$
$$dA^{[l-1]} = {W^{[l]}}^{T} dZ^{[l]}$$
linearBackward <- function(dZ, cache) {
  A_prev <- cache$A
  W <- cache$W
  b <- cache$b
  m <- dim(A_prev)[2]
  dW <- 1/m * (dZ %*% t(A_prev))
  # db is the mean of dZ over the m examples, taken row by row
  db <- matrix(rowSums(dZ) / m, nrow = dim(W)[1])
  dA_prev <- t(W) %*% dZ
  out <- list("dA_prev" = dA_prev,
              "dW" = dW,
              "db" = db)
  return (out)
}
dZ <- L_model_forward$AL - y_train  # for the sigmoid output layer, dZ = AL - Y
linear_backward3 <- linearBackward(dZ, L_model_forward$caches[5]$linear_cache)
linear_backward2 <- linearBackward(linear_backward3$dA_prev, L_model_forward$caches[3]$linear_cache)
linear_backward1 <- linearBackward(linear_backward2$dA_prev, L_model_forward$caches[1]$linear_cache)
## dim dZ3:
## 1 80
##
## Linear Backward 3:
##
## $dA_prev
## [1] 4 80
##
## $dW
## [1] 1 4
##
## $db
## [1] 1 1
##
## Linear Backward 2:
## $dA_prev
## [1] 3 80
##
## $dW
## [1] 4 3
##
## $db
## [1] 4 1
##
## Linear Backward 1:
## $dA_prev
## [1] 2 80
##
## $dW
## [1] 3 2
##
## $db
## [1] 3 1
Next, linearActivationBackward combines the activation backward step with linearBackward:
linearActivationBackward <- function(dA, linear_cache, activation_cache, activation) {
  # First undo the activation to get dZ, then undo the linear step
  if (activation == "relu") {
    dZ <- relu_backward(dA, activation_cache)
  } else if (activation == "sigmoid") {
    dZ <- sigmoid_backward(dA, activation_cache)
  }
  linear_backward <- linearBackward(dZ, linear_cache)
  out <- list("dA_prev" = linear_backward$dA_prev,
              "dW" = linear_backward$dW,
              "db" = linear_backward$db)
  return (out)
}
# build a dA of the right shape (1 x 80) to exercise the backward functions
dA <- apply(X = dZ, MARGIN = c(1, 2), function(x) sigmoid(x)$A)
linear_activation_backward3 <- linearActivationBackward(dA = dA,
linear_cache = L_model_forward$caches[5]$linear_cache,
activation_cache = L_model_forward$caches[6]$activation_cache,
activation = "sigmoid")
linear_activation_backward2 <- linearActivationBackward(dA = linear_activation_backward3$dA_prev,
linear_cache = L_model_forward$caches[3]$linear_cache,
activation_cache = L_model_forward$caches[4]$activation_cache,
activation = "relu")
linear_activation_backward1 <- linearActivationBackward(dA = linear_activation_backward2$dA_prev,
linear_cache = L_model_forward$caches[1]$linear_cache,
activation_cache = L_model_forward$caches[2]$activation_cache,
activation = "relu")
## dim dA3:
## 1 80
##
## Linear Backward 3:
##
## $dA_prev
## [1] 4 80
##
## $dW
## [1] 1 4
##
## $db
## [1] 1 1
##
## Linear Backward 2:
##
## $dA_prev
## [1] 3 80
##
## $dW
## [1] 4 3
##
## $db
## [1] 4 1
##
## Linear Backward 1:
##
## $dA_prev
## [1] 2 80
##
## $dW
## [1] 3 2
##
## $db
## [1] 3 1
In the LModelBackward function, we iterate through all the layers backward, starting from the output layer $L$. At each step, we use the cached values for layer $l$ to backpropagate through layer $l$.
Here, we implement backpropagation for the [LINEAR -> RELU] × (L-1) -> LINEAR -> SIGMOID model.
LModelBackward <- function(AL, Y, caches) {
  # caches is a flat list of (linear_cache, activation_cache) pairs,
  # so the output layer's linear cache sits at position L = length(caches) - 1
  L <- length(caches) - 1
  grads <- list()
  m <- dim(AL)[2]
  # Derivative of the cross-entropy cost with respect to AL
  dA_L <- -((Y/AL) - ((1 - Y)/(1 - AL)))
  # Output layer: LINEAR -> SIGMOID
  current_linear_cache <- caches[L]$linear_cache
  current_activation_cache <- caches[L+1]$activation_cache
  linear_activation_backward_L <- linearActivationBackward(dA = dA_L,
                                                           linear_cache = current_linear_cache,
                                                           activation_cache = current_activation_cache,
                                                           activation = "sigmoid")
  dAL <- paste("dA", floor(L/2)+1, sep="")
  dWL <- paste("dW", floor(L/2)+1, sep="")
  dbL <- paste("db", floor(L/2)+1, sep="")
  grads[[dAL]] <- linear_activation_backward_L$dA_prev
  grads[[dWL]] <- linear_activation_backward_L$dW
  grads[[dbL]] <- linear_activation_backward_L$db
  # Hidden layers: LINEAR -> RELU, walking the cache pairs backwards
  for (i in seq(L-2, 1, -2)) {
    current_linear_cache <- caches[i]$linear_cache
    current_activation_cache <- caches[i+1]$activation_cache
    linear_activation_backward <- linearActivationBackward(dA = grads[[dAL]],
                                                           linear_cache = current_linear_cache,
                                                           activation_cache = current_activation_cache,
                                                           activation = "relu")
    dA_prev_temp <- linear_activation_backward$dA_prev
    dW_temp <- linear_activation_backward$dW
    db_temp <- linear_activation_backward$db
    dAL <- paste("dA", floor(i/2)+1, sep="")
    dWL <- paste("dW", floor(i/2)+1, sep="")
    dbL <- paste("db", floor(i/2)+1, sep="")
    grads[[dAL]] <- dA_prev_temp
    grads[[dWL]] <- dW_temp
    grads[[dbL]] <- db_temp
  }
  return(grads)
}
L_model_backward<- LModelBackward(L_model_forward$AL, y_train, L_model_forward$caches)
L_model_backward
## $dA3
## [,1] [,2] [,3] [,4] [,5] [,6]
## [1,] -0.12881975 0.12881975 -0.12881975 0.5339293 -0.12881975 -0.12881975
## [2,] -0.29714058 0.29714058 -0.29714058 1.2315819 -0.29714058 -0.29714058
## [3,] -0.36946567 0.36946567 -0.36946567 1.5313533 -0.36946567 -0.36946567
## [4,] -0.07328859 0.07328859 -0.07328859 0.3037650 -0.07328859 -0.07328859
## [,7] [,8] [,9] [,10] [,11] [,12]
## [1,] 0.2615377 -0.12881975 0.12881975 0.2209719 -0.10515623 -0.11058810
## [2,] 0.6032728 -0.29714058 0.29714058 0.5097022 -0.24255741 -0.25508676
## [3,] 0.7501116 -0.36946567 0.36946567 0.6337656 -0.30159676 -0.31717580
## [4,] 0.1487949 -0.07328859 0.07328859 0.1257161 -0.05982586 -0.06291618
## [,13] [,14] [,15] [,16] [,17] [,18]
## [1,] -0.12881975 0.2018425 -0.08045412 -0.12405560 -0.12881975 -0.12881975
## [2,] -0.29714058 0.4655776 -0.18557856 -0.28615140 -0.29714058 -0.29714058
## [3,] -0.36946567 0.5789009 -0.23074905 -0.35580169 -0.36946567 -0.36946567
## [4,] -0.07328859 0.1148330 -0.04577224 -0.07057815 -0.07328859 -0.07328859
## [,19] [,20] [,21] [,22] [,23] [,24]
## [1,] 0.5339293 -0.12881975 -0.12881975 -0.12881975 0.12881975 0.2400265
## [2,] 1.2315819 -0.29714058 -0.29714058 -0.29714058 0.29714058 0.5536544
## [3,] 1.5313533 -0.36946567 -0.36946567 -0.36946567 0.36946567 0.6884159
## [4,] 0.3037650 -0.07328859 -0.07328859 -0.07328859 0.07328859 0.1365568
## [,25] [,26] [,27] [,28] [,29] [,30]
## [1,] -0.12881975 -0.07733124 -0.12881975 0.13055836 0.14929783 3.712960
## [2,] -0.29714058 -0.17837521 -0.29714058 0.30115092 0.34437609 8.564455
## [3,] -0.36946567 -0.22179238 -0.36946567 0.37445214 0.42819848 10.649075
## [4,] -0.07328859 -0.04399557 -0.07328859 0.07427773 0.08493905 2.112390
## [,31] [,32] [,33] [,34] [,35] [,36]
## [1,] 0.12881975 -0.12881975 1.0756381 -0.09091711 0.14929783 0.12881975
## [2,] 0.29714058 -0.29714058 2.4811082 -0.20971289 0.34437609 0.29714058
## [3,] 0.36946567 -0.36946567 3.0850189 -0.26075777 0.42819848 0.36946567
## [4,] 0.07328859 -0.07328859 0.6119559 -0.05172489 0.08493905 0.07328859
## [,37] [,38] [,39] [,40] [,41] [,42] [,43]
## [1,] -0.12881975 0.1871260 2.568138 0.2385918 0.1321772 0.12881975 -0.06158663
## [2,] -0.29714058 0.4316321 5.923766 0.5503451 0.3048849 0.29714058 -0.14205808
## [3,] -0.36946567 0.5366929 7.365633 0.6843010 0.3790950 0.36946567 -0.17663553
## [4,] -0.07328859 0.1064604 1.461074 0.1357405 0.0751987 0.07328859 -0.03503808
## [,44] [,45] [,46] [,47] [,48] [,49]
## [1,] -0.12881975 0.14605975 -0.12881975 -0.09942346 -0.07061345 -0.12881975
## [2,] -0.29714058 0.33690702 -0.29714058 -0.22933398 -0.16287968 -0.29714058
## [3,] -0.36946567 0.41891141 -0.36946567 -0.28515469 -0.20252519 -0.36946567
## [4,] -0.07328859 0.08309683 -0.07328859 -0.05656435 -0.04017365 -0.07328859
## [,50] [,51] [,52] [,53] [,54] [,55] [,56]
## [1,] 0.2018425 0.5379501 0.12881975 0.16094104 0.14089646 0.2197007 -0.11696838
## [2,] 0.4655776 1.2408563 0.29714058 0.37123278 0.32499719 0.5067702 -0.26980377
## [3,] 0.5789009 1.5428852 0.36946567 0.46159218 0.40410268 0.6301198 -0.33547498
## [4,] 0.1148330 0.3060525 0.07328859 0.09156315 0.08015932 0.1249929 -0.06654607
## [,57] [,58] [,59] [,60] [,61] [,62]
## [1,] -0.12881975 0.3195830 0.12881975 0.13136251 -0.06699816 0.13109330
## [2,] -0.29714058 0.7371625 0.29714058 0.30300581 -0.15454054 0.30238484
## [3,] -0.36946567 0.9165905 0.36946567 0.37675851 -0.19215626 0.37598640
## [4,] -0.07328859 0.1818183 0.07328859 0.07473523 -0.03811683 0.07458207
## [,63] [,64] [,65] [,66] [,67] [,68]
## [1,] -0.07061345 -0.12881975 -0.11658580 -0.12881975 -0.12881975 1.7396089
## [2,] -0.16287968 -0.29714058 -0.26892129 -0.29714058 -0.29714058 4.0126487
## [3,] -0.20252519 -0.36946567 -0.33437770 -0.36946567 -0.36946567 4.9893419
## [4,] -0.04017365 -0.07328859 -0.06632841 -0.07328859 -0.07328859 0.9897045
## [,69] [,70] [,71] [,72] [,73] [,74]
## [1,] 10.341757 -0.12881975 -0.08615074 0.2197007 0.12881975 -0.12881975
## [2,] 23.854693 -0.29714058 -0.19871860 0.5067702 0.29714058 -0.29714058
## [3,] 29.661012 -0.36946567 -0.24708742 0.6301198 0.36946567 -0.36946567
## [4,] 5.883669 -0.07328859 -0.04901319 0.1249929 0.07328859 -0.07328859
## [,75] [,76] [,77] [,78] [,79] [,80]
## [1,] -0.06677242 -0.08941752 0.17319913 1.0844270 0.3873839 0.3174612
## [2,] -0.15401983 -0.20625387 0.39950776 2.5013808 0.8935547 0.7322683
## [3,] -0.19150881 -0.25645681 0.49674939 3.1102260 1.1110491 0.9105050
## [4,] -0.03798840 -0.05087173 0.09853707 0.6169561 0.2203919 0.1806112
##
## $dW3
## [,1] [,2] [,3] [,4]
## Species 0.496205 2.099032 3.669874 0.8591043
##
## $db3
## [,1]
## [1,] 1.059136
##
## $dA2
## [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10] [,11]
## [1,] 0 0 0 1.4830033 0 0 0.726428 0 0 0.6137554 -0.2920743
## [2,] 0 0 0 0.7450906 0 0 0.364972 0 0 0.3083630 -0.1467440
## [3,] 0 0 0 2.2674835 0 0 1.110694 0 0 0.9384202 -0.4465760
## [,12] [,13] [,14] [,15] [,16] [,17] [,18] [,19]
## [1,] -0.3071615 0 0.5606231 -0.2234635 -0.3445678 0 0 1.4830033
## [2,] -0.1543241 0 0.2816682 -0.1122725 -0.1731178 0 0 0.7450906
## [3,] -0.4696440 0 0.8571818 -0.3416714 -0.5268376 0 0 2.2674835
## [,20] [,21] [,22] [,23] [,24] [,25] [,26] [,27] [,28]
## [1,] 0 0 0 0 0.6666803 0 -0.2147896 0 0.3626294
## [2,] 0 0 0 0 0.3349535 0 -0.1079146 0 0.1821923
## [3,] 0 0 0 0 1.0193413 0 -0.3284092 0 0.5544534
## [,29] [,30] [,31] [,32] [,33] [,34] [,35] [,36] [,37]
## [1,] 0.4146788 10.312847 0 0 2.987614 -0.2525248 0.4146788 0 0
## [2,] 0.2083429 5.181381 0 0 1.501037 -0.1268735 0.2083429 0 0
## [3,] 0.6340359 15.768145 0 0 4.568005 -0.3861055 0.6340359 0 0
## [,38] [,39] [,40] [,41] [,42] [,43] [,44] [,45]
## [1,] 0.5197476 7.133074 0.6626953 0.3671257 0 -0.17105854 0 0.4056850
## [2,] 0.2611316 3.583799 0.3329514 0.1844513 0 -0.08594324 0 0.2038242
## [3,] 0.7946841 10.906333 1.0132484 0.5613281 0 -0.26154522 0 0.6202845
## [,46] [,47] [,48] [,49] [,50] [,51] [,52] [,53]
## [1,] 0 -0.2761514 -0.19613078 0 0.5606231 1.4941710 0 0.4470181
## [2,] 0 -0.1387440 -0.09854003 0 0.2816682 0.7507015 0 0.2245909
## [3,] 0 -0.4222302 -0.29988018 0 0.8571818 2.2845587 0 0.6834821
## [,54] [,55] [,56] [,57] [,58] [,59] [,60] [,61]
## [1,] 0.3913438 0.6102248 -0.3248829 0 0.8876507 0 0.3648630 -0.18608923
## [2,] 0.1966190 0.3065892 -0.1632277 0 0.4459735 0 0.1833145 -0.09349496
## [3,] 0.5983571 0.9330220 -0.4967397 0 1.3572007 0 0.5578684 -0.28452685
## [,62] [,63] [,64] [,65] [,66] [,67] [,68] [,69] [,70]
## [1,] 0.3641152 -0.19613078 0 -0.3238203 0 0 4.831811 28.72451 0
## [2,] 0.1829388 -0.09854003 0 -0.1626938 0 0 2.427599 14.43177 0
## [3,] 0.5567252 -0.29988018 0 -0.4951149 0 0 7.387747 43.91923 0
## [,71] [,72] [,73] [,74] [,75] [,76] [,77] [,78]
## [1,] -0.2392860 0.6102248 0 0 -0.18546223 -0.2483596 0.4810653 3.012026
## [2,] -0.1202221 0.3065892 0 0 -0.09317994 -0.1247808 0.2416969 1.513302
## [3,] -0.3658637 0.9330220 0 0 -0.28356817 -0.3797370 0.7355396 4.605329
## [,79] [,80]
## [1,] 1.0759695 0.8817573
## [2,] 0.5405886 0.4430125
## [3,] 1.6451366 1.3481899
##
## $dW2
## [,1] [,2] [,3]
## [1,] 0.3779573 0.5864597 0.6154363
## [2,] 0.8718107 1.3527504 1.4195891
## [3,] 1.0840126 1.6820147 1.7651222
## [4,] 0.2150288 0.3336507 0.3501362
##
## $db2
## [,1]
## [1,] 2.120485
## [2,] 2.120485
## [3,] 2.120485
## [4,] 2.120485
##
## $dA1
## [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10] [,11]
## [1,] 0 0 0 1.670193 0 0 0.8181204 0 0 0.6912259 -0.2746920
## [2,] 0 0 0 3.638665 0 0 1.7823481 0 0 1.5058971 -0.4196493
## [,12] [,13] [,14] [,15] [,16] [,17] [,18] [,19]
## [1,] -0.2888813 0 0.6313869 -0.2429370 -0.3240614 0 0 1.670193
## [2,] -0.4413264 0 1.3755327 -0.4208483 -0.4950714 0 0 3.638665
## [,20] [,21] [,22] [,23] [,24] [,25] [,26] [,27] [,28]
## [1,] 0 0 0 0 0.7508311 0 -0.2335073 0 0.01417142
## [2,] 0 0 0 0 1.6357523 0 -0.4045128 0 0.20680063
## [,29] [,30] [,31] [,32] [,33] [,34] [,35] [,36] [,37]
## [1,] 0.07702119 11.61457 0 0 3.364722 -0.2745308 0.07702119 0 0
## [2,] 0.42164086 25.30338 0 0 7.330346 -0.4755793 0.42164086 0 0
## [,38] [,39] [,40] [,41] [,42] [,43] [,44] [,45]
## [1,] 0.5853521 8.033437 0.7463431 0.01434713 0 -0.1926502 0 0.0753507
## [2,] 1.2752416 17.501558 1.6259750 0.20936475 0 -0.4197056 0 0.4124960
## [,46] [,47] [,48] [,49] [,50] [,51] [,52] [,53]
## [1,] 0 -0.3002163 -0.2208871 0 0.6313869 1.682771 0 0.5034424
## [2,] 0 -0.5200752 -0.4812223 0 1.3755327 3.666066 0 1.0967941
## [,54] [,55] [,56] [,57] [,58] [,59] [,60]
## [1,] 0.07268702 0.6872496 -0.3055480 0 0.9996931 0 0.01425871
## [2,] 0.39791407 1.4972345 -0.4667883 0 2.1779205 0 0.20807439
## [,61] [,62] [,63] [,64] [,65] [,66] [,67] [,68]
## [1,] -0.2095781 0.01422949 -0.2208871 0 -0.3045486 0 0 5.44170
## [2,] -0.4565845 0.20764797 -0.4812223 0 -0.4652615 0 0 11.85523
## [,69] [,70] [,71] [,72] [,73] [,74] [,75] [,76]
## [1,] 32.35023 0 -0.2694896 0.6872496 0 0 -0.2088719 -0.2700027
## [2,] 70.47785 0 -0.5871071 1.4972345 0 0 -0.4550461 -0.4677350
## [,77] [,78] [,79] [,80]
## [1,] 0.5417871 3.392214 1.211782 0.9930559
## [2,] 1.1803315 7.390241 2.639976 2.1634607
##
## $dW1
## Sepal.Length Sepal.Width
## [1,] -0.03778349 1.8434129
## [2,] -0.02028630 0.9275388
## [3,] -0.05709898 2.8187046
##
## $db1
## [,1]
## [1,] 2.624512
## [2,] 2.624512
## [3,] 2.624512
updateParameters <- function(parameters, grads, learning_rate){
  L = floor(length(parameters)/2)
  # Gradient descent step for every W and b: theta <- theta - learning_rate * d_theta
  for (i in 1:L) {
    W <- paste("W", i, sep="")
    dW <- paste("dW", i, sep="")
    b <- paste("b", i, sep="")
    db <- paste("db", i, sep="")
    parameters[[W]] <- parameters[[W]] - learning_rate * grads[[dW]]
    parameters[[b]] <- parameters[[b]] - learning_rate * grads[[db]]
  }
  return (parameters)
}
update_params <- updateParameters(parameters = params, grads = L_model_backward, learning_rate = 0.1)
update_params
## $W1
## Sepal.Length Sepal.Width
## [1,] 0.04285797 0.3859396
## [2,] 0.29393056 0.7959609
## [3,] 0.62081683 0.6578336
##
## $b1
## [,1]
## [1,] -0.2624512
## [2,] -0.2624512
## [3,] -0.2624512
##
## $W2
## [,1] [,2] [,3]
## [1,] 0.04234196 0.003907636 0.05061452
## [2,] 0.17696322 -0.024029181 0.52460332
## [3,] 0.60505309 0.195219223 0.67337768
## [4,] 0.05207187 0.026409725 0.24543763
##
## $b2
## [,1]
## [1,] -0.2120485
## [2,] -0.2120485
## [3,] -0.2120485
## [4,] -0.2120485
##
## $W3
## [,1] [,2] [,3] [,4]
## Species 0.2244597 0.4223009 0.4190973 0.07002028
##
## $b3
## [,1]
## [1,] -0.1059136
trainModel <- function(X, y, layer_size_list, iterations, learning_rate = 0.01) {
  init_params <- initializeParameters(layer_size_list)
  cost_history <- c()
  for (i in 1:iterations) {
    # One full training step: forward pass, cost, backward pass, parameter update
    forward_pass <- LModelForward(X, init_params)
    cost <- computeCost(forward_pass$AL, y)
    backward_pass <- LModelBackward(forward_pass$AL, y, forward_pass$caches)
    init_params <- updateParameters(init_params, backward_pass, learning_rate = learning_rate)
    cost_history <- c(cost_history, cost)
    cat("Iteration", i, " | Cost: ", cost_history[i], "\n")
  }
  model_out <- list("updated_params" = init_params,
                    "cost_hist" = cost_history)
  return (model_out)
}
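The iteration log below was produced by a training call whose exact arguments are not shown; something along these lines would reproduce it (the learning rate is an assumption):
trained_model <- trainModel(X = X_train, y = y_train,
                            layer_size_list = list(2, 3, 4, 1),
                            iterations = 10, learning_rate = 0.1)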
## Iteration 1 | Cost: 1.165654
## Iteration 2 | Cost: 0.856832
## Iteration 3 | Cost: 0.844496
## Iteration 4 | Cost: 0.8342427
## Iteration 5 | Cost: 0.8257285
## Iteration 6 | Cost: 0.8182522
## Iteration 7 | Cost: 0.8116161
## Iteration 8 | Cost: 0.8057899
## Iteration 9 | Cost: 0.8005528
## Iteration 10 | Cost: 0.7957285
makePrediction <- function(X, parameters){
  # Run a forward pass with the trained parameters and return the
  # predicted probabilities AL
  forward_pass <- LModelForward(X, parameters)
  pred <- forward_pass$AL
  return (pred)
}
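A hypothetical usage that would yield binary predictions and a confusion matrix like the ones shown below (the thresholding and table() calls are assumptions, since the original calls are not shown):
probs <- makePrediction(X_test, trained_model$updated_params)
y_pred <- ifelse(probs > 0.5, 1, 0)  # threshold the probabilities at 0.5
table(y_test, y_pred)                # confusion matrix
mean(y_pred == y_test)               # accuracy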
## [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10] [,11] [,12] [,13] [,14]
## [1,] 1 1 1 1 1 0 0 0 1 1 0 1 1 0
## [,15] [,16] [,17] [,18] [,19] [,20]
## [1,] 1 1 1 0 1 0
## [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10] [,11] [,12] [,13]
## Species 0 0 0 0 1 1 1 0 0 0 0 1 0
## [,14] [,15] [,16] [,17] [,18] [,19] [,20]
## Species 0 1 1 1 1 1 1
## y_pred
## y_test 0 1
## 0 3 7
## 1 4 6
## We are getting an accuracy of 45 %.